import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from matplotlib import collections as matcoll
import os
import math

# This file sets up a plotting style using Matplotlib and seaborn, and plots the
# training loss of Local SGD variants against the number of communication rounds.

###################################################################################################
# Tweaking seaborn to make our curves more beautiful :)
# Seaborn lets us set Matplotlib rc parameters directly through its own API.
# Inspired by: https://towardsdatascience.com/making-matplotlib-beautiful-by-default-d0d41e3534fd

# Matplotlib rc overrides applied through seaborn: transparent axes background,
# light-grey axis edges, no grid, no top/right spines, no tick marks on any side,
# dim-grey axis labels and tick labels, black text, white patch edges.
_STYLE_RC = {
    'axes.axisbelow': False,
    'axes.edgecolor': 'lightgrey',
    'axes.facecolor': 'None',
    'axes.grid': False,
    'axes.labelcolor': 'dimgrey',
    'axes.spines.right': False,
    'axes.spines.top': False,
    'figure.facecolor': 'white',
    'lines.solid_capstyle': 'round',
    'patch.edgecolor': 'w',
    'patch.force_edgecolor': True,
    'text.color': 'black',
    'xtick.bottom': False,
    'xtick.color': 'dimgrey',
    'xtick.direction': 'out',
    'xtick.top': False,
    'ytick.color': 'dimgrey',
    'ytick.direction': 'out',
    'ytick.left': False,
    'ytick.right': False,
}
sns.set(font='Franklin Gothic Book', rc=_STYLE_RC)

# Global font sizes layered on top of seaborn's "notebook" context.
_CONTEXT_RC = {
    "font.size": 16,
    "axes.titlesize": 18,
    "axes.labelsize": 18,
}
sns.set_context("notebook", rc=_CONTEXT_RC)

# Named hex colours used for plotted lines.
CB91_Blue = '#2CBDFE'
CB91_Green = '#47DBCD'
CB91_Pink = '#F3A0F2'
CB91_Purple = '#9D2EC5'
CB91_Violet = '#661D98'
CB91_Amber = '#F5B14C'
CB91_Black = '#000000'

# Default line-colour cycle: successive lines on an axis take these colours in order.
# NOTE(review): CB91_Black appears twice (2nd and 5th entries) and CB91_Pink is
# unused; confirm this is intended.
color_list = [
    CB91_Blue,
    CB91_Black,
    CB91_Green,
    CB91_Purple,
    CB91_Black,
    CB91_Amber,
    CB91_Violet,
]
plt.rcParams['axes.prop_cycle'] = plt.cycler(color=color_list)
plt.rcParams['lines.markeredgewidth'] = 1

###########################################################################################

# Value substituted into the results path ("homogeneity=<value>").
homogeneity = 0.1

# Local-step counts K for which results are looked up; every K gets a dict entry,
# left as None when no run with a non-empty history is found.
K_VALUES = ["16", "32", "64", "128", "256", "512", "1024"]

# Experiment folder holding the L-SARAH results for each K: small K come from
# experiment 01, large K from experiment 03.
LSARAH_EXPERIMENT = {"16": "01", "32": "01", "64": "01", "128": "01",
                     "256": "03", "512": "03", "1024": "03"}


def _result_dir(experiment, method, K):
    """Return the results directory for one (experiment, method, K) combination."""
    return f"../results/{experiment}/fc/cifar10/homogeneity={homogeneity}/{method}_K={K}_b=16/"


def _best_history(path):
    """Return the ``train_loss`` history whose minimum is lowest among the runs in *path*.

    Each entry of *path* is treated as a run directory containing ``seed=0.pickle``,
    whose unpickled object is indexed with ``'train_loss'``. Runs with an empty
    history are skipped. Returns ``None`` if no run has a non-empty history.
    Entries are visited in sorted order so ties are broken deterministically.
    """
    best_loss = np.inf
    best_history = None
    for run in sorted(os.listdir(path)):
        # allow_pickle=True unpickles arbitrary objects: only use on trusted results.
        history = np.load(os.path.join(path, run, "seed=0.pickle"),
                          allow_pickle=True)['train_loss']
        if len(history) == 0:
            continue
        run_min = np.min(history)
        # Strict '<' keeps the first run on ties; NaN minima never win.
        if run_min < best_loss:
            best_loss = run_min
            best_history = history
    return best_history


# Best train-loss history per K for L-STORM (all from experiment 03).
lstorm = {K: _best_history(_result_dir("03", "lstorm", K)) for K in K_VALUES}

# Best train-loss history per K for L-SARAH (experiment chosen per K above).
lsarah = {K: _best_history(_result_dir(LSARAH_EXPERIMENT[K], "lsarah", K))
          for K in K_VALUES}
            
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
# fig.suptitle('CE-LSGD v/s BVR-LSGD in the Intermittent Communication Setting')

# One panel per K (left: K=64, right: K=1024). Each compares BVR-LSGD (lsarah
# histories) with CE-LSGD (lstorm histories) on log10-log10 axes.
for ax_index, (ax, K) in enumerate(zip(axes, ["64", "1024"])):
    OUR = lstorm[K]
    BVR = lsarah[K]
    if OUR is None or BVR is None:
        raise ValueError(f"No non-empty train_loss history found for K={K}")

    # x-positions of the logged points, in communication rounds per the x-axis
    # label: BVR-LSGD's spacing depends on K via T_BVR and rho_BVR, while
    # CE-LSGD's is a fixed 20 rounds per logged point.
    T_BVR = math.floor(1 + 5000.0 / (16 * int(K)))
    rho_BVR = math.ceil(5000.0 / (16 * int(K)))
    x_BVR = 10 * (2 + rho_BVR / T_BVR) * np.arange(1, len(BVR) + 1)
    x_OUR = 20 * np.arange(1, len(OUR) + 1)

    ax.plot(np.log10(x_BVR), np.log10(BVR), label="BVR-LSGD", linewidth=3)
    ax.plot(np.log10(x_OUR), np.log10(OUR), label="CE-LSGD", linewidth=3)

    # Raw strings: "\l" is not a valid Python escape sequence.
    ax.set(title=f"$K={K}$",
           xlabel=r"$\log_{10}$(Number of Communication Rounds)",
           ylabel=r"$\log_{10}$(Train Loss)")
    # Only the first panel carries the legend.
    if ax_index == 0:
        ax.legend(prop={'size': 20})

plt.tight_layout()
# savefig does not create missing directories.
os.makedirs("figs", exist_ok=True)
plt.savefig("figs/intermittent.png", dpi=150)
            

            